基本信息
标题:Learning to propagate labels:Transductive propagation network for few-shot learning
年份:2019
- 期刊:ICLR
- 标签:transductive inference, few-shot learning
- 数据:miniImageNet,tieredImageNet
创新点
- 第一个提出用于Few-shot Learning的转导推理(transductive inference)
- 转导推理中,我们建议通过情景元学习(episodic meta-learning)来学习在不可见类的数据实例之间传播标签,这比基于启发式的标签传播算法表现要好得多
- 对于Few-shot Learning,我们的方法在miniImageNet,tieredImageNet数据集上均达到了最优性能。并且,对于Semi-supervised learning,我们的方法也超过了其它的Semi-supervised Few-shot Learning baselines
创新点来源
小样本学习(few-shot learning)的目标是在已知类别(Seen Class)训练一个分类模型,使它能够在只有少量数据的未知类别(Unseen Class)上面具有很好的泛化性能。小样本学习面临两个重要的问题:(1)已知类别和未知类别之间没有交集,导致它们的数据分布差别很大,不能直接通过训练分类器和微调(finetune)的方式得到很好的性能;(2)未知类别只有极少量数据(每个类别1或者5个训练样本),导致分类器学习不可靠。
对于第一个问题,Matching Networks提出了episodic training的策略。一个episode就是一个小样本学习的子任务,包含训练集和测试集。这里的episode类似于深度学习中的mini-batch的概念。
对于第二个问题,目前解决方法较少。我们提出利用转导(Transductive)的思想,拿到所有无标注数据,建立权重图,得到整个测试集的全部预测结果,如下图:
主要过程
问题定义
给定一个类别数相对较多的标记数据集,类别集合为${\cal C}_{train}$,目标是训练一个分类器,能够对只有少量标记样本的不可见的新类进行分类,新类的类别集合为${\cal C}_{test}$。
特别的,在每一个episoid中,从${\cal C}_{train}$中采样$N$类去构造支撑集和查询集。支撑集中,每一个类别包含$K$样本(即N-way K-shot),表示为${\cal S}=\{({\rm x}_1,y_1),({\rm x}_2,y_2), \cdots, ({\rm x}_{N \times K},y_{N \times K}) \}$,查询集集合为${\cal Q}=\{({\rm x}_1^,y_1^),({\rm x}_2^,y_2^), \cdots, ({\rm x}_T^,y_T^)$,包含来着这$N$类的不同样本。每一个episoid中,支撑集$\cal S$作为标记训练集,模型训练时,对查询集$\cal Q$进行预测,并使得损失最小。然后一个episode一个episode地训练,直到收敛。
通过episode训练实现的元学习在Few-shot Learning任务中表现良好。然而,因为支撑集中标记实例很匮乏($K$通常很小),我们观察到一个可靠的分类器仍然很难获得。这激发我们设计传导,利用整个查询集去预测,而不是独立的去预测每一个样本。将整个查询集考虑进来,可以减轻low-data问题,产生更加可靠的泛化能力。
TPN
如图所示,Transductive Propagation Network(TPN)包含四个部分:
- feature embedding :由CNN组成
- graph construction:生成example-wise参数以利用流形结构
- label propagation:将标签信息从支撑集$\cal S$扩散到查询集$\cal Q$
- loss:计算$\cal Q$中传播出的标签与真实标签之间的交叉熵,对框架中的所有参数进行优化
feature embedding
输入${\rm x}_i$到CNN$f_{\varphi}$中提取特征,$f_{\varphi}({\rm x}_i, \varphi)$表示特征图,$\varphi$表示网络中的参数。网络由四个卷积blocks组成。在每个block中,首先是一个$3 \times 3,filter=64$的卷积层,接着是一个batch-normalization层,Relu非线性层,$2\times 2$的最大池化层。我们为支撑集$\cal S$和查询集$\cal Q$应用相同的$f_{\varphi}$。
graph construction
这里使用的是高斯相似函数:
其中$d(\cdot,\cdot)$是距离度量(例如欧式距离),$\sigma$是长度比例参数。根据不同的$\sigma$,图结构也会不同,所以要谨慎地选择$\sigma$,实现最优的类别传播结果。此外,我们观察到在元学习框架中,调节该参数没有原则。
下面介绍一下Example-wise length-scale parameter 。
为了获得元学习中合适的图结构,我们提出了一个基于支持集和查询集的联合集合的图构造模块。这个模块由CNN$g_{\Phi}$构成。以${\rm x}_i \in {\cal S} \cup {\cal Q}$的特征图$f_{\varphi}({\rm x}_i)$为输入,产生example-wise length-scale parameter$\sigma_i = g_{\Phi}(f_{\varphi}({\rm x}_i))$。值得注意的是,scale参数是example-wise的,在元训练过程中学习得到,这使得它能够适用于不同的task,更加适合Few-shot Learning。得到example-wise的$\sigma _i$后,相似性函数定义如下:
其中$W \in R^{(N \times K+T) \times (N \times K+T) }$,对于$W$中每一行,我们只保留最大的$k$个值,构造一个$k$近邻图。然后在$W$上应用normalized graph Laplacians,即$S=D^{-1/2}WD^{-1/2}$,其中$D$为对角矩阵,对角线上的元素为$W$中第$i$行元素之和。
下面介绍一下Graph construction structure 。
所提的graph construction module如下图所示,它由两个卷积block和两个fc组成。两个卷积block中,卷积核的数目分别为64和1。为了提供example-wise scaling parameter,将第二个卷积层输出的特征图输入到两层fc中,两层fc的神经元数目分别为8和1。
下面介绍一下Graph construction in each episode。
在Few-shot 元学习中,我们遵循episode模式。也就是说,对于每一个episode的每一个task中,图都是单独构造出来的,如下图所示。在5-way,5-shot学习中,$N=5,K=5,T=75$,$W$的维度为$100 \times 100$,可以保证这个过程很高效。
label propagation
记$\cal F$为$(N \times K+T) \times N$(可以看做$N \times K+T$个样本属于$K$类的概率)非负矩阵集合,定义类别矩阵$Y \in {\cal F}$。若${\rm x}_i$来自支撑集且类标$y_i=j$,则$Y_{ij}=1$,否则$Y_{ij}=0$。标签传播算法从$Y$开始根据图结构和下面的函数进行迭代,产生联合集合${\cal S} \bigcup {\cal Q}$中未知标签实例的类别。
其中,$F_t \in {\cal F}$表示第$t$步预测出的类别,$S$表示标准化后的权重,$\alpha \in (0,1)$控制传播信息的总量。很容易得知序列$\{F_t \}$有一个闭合解(没有推导):
其中$I$为单位矩阵。直接利用这个结果去做标签传播,可以使得整个episode元学习过程更加高效。
现在讨论一下(4)式的时间复杂度,矩阵求逆时间复杂度为$O(n^3)$,但是在我们的设置中,$n=N\times K+T$非常小。
loss
我们计算${\cal S} \bigcup {\cal Q}$中预测得分$F^$和真实标记之间的交叉熵损失,从而实现端到端的训练。其中$F^$经过了softmax函数。
从这可以看出,$F^*$为矩阵。$\tilde y_i$表示第$i$个实例中最终预测出的标签。损失函数定义如下:
其中,$y_i$表示样本${\rm x}_i$的真实标记,${\mathbb I}(\cdot)$为标记函数。
实验结果
数据集
miniImageNet数据集有100类,每一个类有600个样本,每个样本的大小为$84 \times 84$,其中64个类别用于训练,16个类别用于验证,20个类别用于测试。
tiredImageNet,351类用于训练,97类用于验证,160类用于测试。每类的平均样本数量为1281。将该数据集每个图片均resize到$84 \times 84$尺寸。
实验设置
上面的超参数$k=20$,即20近邻图。标签传播中$\alpha = 0.99$。
我们发现,训练过程也有更多的样本结果会更好,因此对于1-shot test和5-shot test,在训练过程中,分别设置5-shot和10-shot。在所有实验中,查询数量设置为15,测试过程中的精度均随机产生600个task进行平均。
Few-shot Learning结果
缺点
- 构建图的时候,$d$可以改成CNN,使用欧式距离不一定适合网络。可以使用CNN让网络自己学习到适合自身的度量方式。
- $W$太大的时候计算复杂度太高。
启发
- 支撑集也是可以求损失的,这也是利用支撑集信息的一种方式
- 本文中可训练参数跟支撑集和查询集大小无关,因此可以支持训练集的支撑集数量和测试集中不同,查询集个数也可以不同
思考
什么是转导思想?
答:文章中提到利用转导(Transductive)的思想,拿到所有无标注数据,建立权重图,得到整个测试集的全部预测结果。所以在论文中,task中查询集的数量$T=75$,在一般的模型中,$T$通常比较小,例如等于1。这里拿出所有的无标注数据预测,也有点像半监督。